Source code for hysop.operator.parameter_plotter

# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from abc import abstractmethod
from hysop.tools.htypes import to_tuple, check_instance, first_not_None
from hysop.tools.numpywrappers import npw
from hysop.tools.io_utils import IO
from hysop.core.graph.graph import op_apply
from hysop.core.graph.computational_graph import ComputationalGraphOperator
from hysop.parameters.scalar_parameter import ScalarParameter
from hysop.parameters.tensor_parameter import TensorParameter
from hysop.backend.host.host_operator import HostOperatorBase


class PlottingOperator(HostOperatorBase):
    """
    Base operator for plotting.

    Owns a matplotlib figure and its axes. On every ``apply`` it runs two
    independently scheduled steps:

    * update: when the schedule cloned from ``io_params`` with
      ``frequency=update_frequency`` says so, call the subclass ``update``
      hook and, on the visualization rank only, redraw the figure;
    * save: when the schedule cloned with ``frequency=save_frequency`` says
      so, write the figure to ``<dump_dir>/<name>.png``.
    """

    @classmethod
    def supports_mpi(cls):
        return True

    def __new__(
        cls,
        name=None,
        dump_dir=None,
        update_frequency=1,
        save_frequency=100,
        axes_shape=(1,),
        figsize=(30, 18),
        visu_rank=0,
        fig=None,
        axes=None,
        **kwds,
    ):
        # Plotting-specific arguments are consumed by __init__ only; the base
        # class only receives the remaining keywords.
        return super().__new__(cls, **kwds)

    def __init__(
        self,
        name=None,
        dump_dir=None,
        update_frequency=1,
        save_frequency=100,
        axes_shape=(1,),
        figsize=(30, 18),
        visu_rank=0,
        fig=None,
        axes=None,
        **kwds,
    ):
        """
        Parameters
        ----------
        name: str
            Operator name, also used as the image file stem.
        dump_dir: str, optional
            Output directory for the saved image; defaults to
            ``IO.default_path()``.
        update_frequency, save_frequency: int >= 0
            Frequencies given to the cloned IO schedules for update and save.
        axes_shape: tuple or None
            Subplot grid shape used to create the figure (when ``fig`` is
            None) and to reshape ``axes``.
        figsize: tuple
            Figure size forwarded to ``plt.subplots``.
        visu_rank: int
            MPI rank that performs the interactive drawing.
        fig, axes: optional
            User-provided figure and axes. They must be given together.
        kwds:
            Forwarded to HostOperatorBase.

        Raises
        ------
        RuntimeError
            If exactly one of ``fig`` and ``axes`` is provided.
        """
        import matplotlib.pyplot as plt

        check_instance(name, str)
        for frequency in (update_frequency, save_frequency):
            check_instance(frequency, int, minval=0)
        check_instance(axes_shape, tuple, minsize=1, allow_none=True)

        super().__init__(**kwds)

        # Both or neither: a figure without axes (or the reverse) is ambiguous.
        if (fig is None) != (axes is None):
            raise RuntimeError("figure and axes should be specified at the same time.")

        # first_not_None evaluates IO.default_path() unconditionally.
        imgpath = "{}/{}.png".format(first_not_None(dump_dir, IO.default_path()), name)

        if fig is None:
            fig, axes = plt.subplots(*axes_shape, figsize=figsize)
            # Event callbacks are only wired for figures created here.
            canvas = fig.canvas
            canvas.mpl_connect("key_press_event", self.on_key_press)
            canvas.mpl_connect("close_event", self.on_close)

        self.fig = fig
        self.axes = npw.asarray(axes).reshape(axes_shape)
        self.update_frequency = update_frequency
        self.save_frequency = save_frequency
        self.imgpath = imgpath
        # Only one rank draws interactively.
        self.should_draw = visu_rank == self.mpi_params.rank
        # Cleared when the window is closed or 'q' is pressed.
        self.running = True
        self.plt = plt
        self.update_ioparams = self.io_params.clone(
            frequency=update_frequency, with_last=True
        )
        self.save_ioparams = self.io_params.clone(
            frequency=save_frequency, with_last=True
        )

    def draw(self):
        """Redraw and show the figure, then pause briefly for the GUI loop."""
        if self.running:
            self.fig.canvas.draw()
            self.fig.show()
            self.plt.pause(0.001)

    @op_apply
    def apply(self, **kwds):
        # Update first, then save, so the saved image reflects the latest data.
        for step in (self._update, self._save):
            step(**kwds)

    def _update(self, simulation, **kwds):
        """Call the ``update`` hook on its schedule; draw on the visu rank."""
        if not self.update_ioparams.should_dump(simulation=simulation):
            return
        self.update(simulation=simulation, **kwds)
        if self.should_draw:
            self.draw()

    def _save(self, simulation, **kwds):
        """Call ``save`` on its schedule.

        NOTE(review): there is no rank check here, so every MPI rank saves to
        ``imgpath``. Confirm that the dump directories differ per rank, or that
        concurrent writes to the same file are acceptable.
        """
        if not self.save_ioparams.should_dump(simulation=simulation):
            return
        self.save(simulation=simulation, **kwds)

    @abstractmethod
    def update(self, **kwds):
        """Refresh the plotted data; must be implemented by subclasses."""

    def save(self, **kwds):
        """Write the figure to ``self.imgpath`` at the figure's own dpi."""
        figure = self.fig
        figure.savefig(self.imgpath, dpi=figure.dpi, bbox_inches="tight")

    def on_close(self, event):
        """Matplotlib close_event callback: stop drawing."""
        self.running = False

    def on_key_press(self, event):
        """Matplotlib key_press_event callback: 'q' closes the figure."""
        if event.key != "q":
            return
        self.plt.close(self.fig)
        self.running = False
class ParameterPlotter(PlottingOperator):
    """
    Base operator to plot parameters during runtime.

    Each tracked scalar value gets one matplotlib line. Every update appends
    the current simulation time and parameter values to growable history
    buffers, then refreshes the line data.
    """

    def __init__(
        self, name, parameters, alloc_size=128, fig=None, axes=None, shape=None, **kwds
    ):
        """
        Parameters
        ----------
        name: str
            Operator name, forwarded to PlottingOperator.
        parameters:
            With custom ``fig`` and ``axes``: a dict mapping matplotlib Axes
            to dicts {label: ScalarParameter}.
            Otherwise, one of:
            * a TensorParameter, plotted in subplot 0;
            * a list or tuple, with one subplot per entry;
            * a dict mapping subplot positions (int or tuple) to a
              TensorParameter, a list/tuple of TensorParameters, or a dict
              {label: TensorParameter}.
            Non-scalar tensor parameters are split into one line per
            component, labelled "<name>_<index>".
        alloc_size: int
            Initial capacity of the history buffers. The buffers are doubled
            when full.
        fig, axes: optional
            User-provided figure and axes. They must be given together.
        shape:
            Unused; kept for backward compatibility.
        kwds:
            Forwarded to PlottingOperator.

        Raises
        ------
        TypeError
            If ``parameters`` (or one of its entries) has an unsupported type.
        """
        # Every parameter read by this operator, declared to the graph as
        # input parameters.
        input_params = set()

        if (fig is not None) and (axes is not None):
            # Import the submodule explicitly so that ``matplotlib.axes`` is
            # guaranteed to be bound.
            import matplotlib.axes

            custom_axes = True
            axes_shape = None
            check_instance(parameters, dict, keys=matplotlib.axes.Axes, values=dict)
            for params in parameters.values():
                check_instance(params, dict, keys=str, values=ScalarParameter)
                input_params.update(params.values())
        else:
            custom_axes = False

            # Normalize the user input into {position: params}.
            if isinstance(parameters, TensorParameter):
                _parameters = {0: parameters}
            elif isinstance(parameters, (list, tuple)):
                # One subplot per entry (previously every position received
                # the whole list).
                _parameters = dict(enumerate(parameters))
            elif isinstance(parameters, dict):
                _parameters = parameters.copy()
            else:
                raise TypeError(type(parameters))
            check_instance(
                _parameters,
                dict,
                keys=(int, tuple, list),
                values=(TensorParameter, list, tuple, dict),
            )

            parameters = {}
            axes_shape = (1,) * 2
            for pos, params in _parameters.items():
                # Left-pad positions with zeros to obtain 2D subplot indices.
                pos = to_tuple(pos)
                pos = (2 - len(pos)) * (0,) + pos
                check_instance(pos, tuple, values=int)
                # Grow the subplot grid so that ``pos`` is a valid index.
                axes_shape = tuple(max(p0, p1 + 1) for (p0, p1) in zip(axes_shape, pos))

                # Normalize every entry into {label: TensorParameter}.
                if isinstance(params, TensorParameter):
                    params = [params]
                if isinstance(params, (list, tuple)):
                    params = {p.name: p for p in params}
                elif not isinstance(params, dict):
                    raise TypeError(type(params))
                check_instance(params, dict, keys=str, values=TensorParameter)
                input_params.update(params.values())

                # Split non-scalar tensors into one scalar view per component.
                _params = {}
                for pname, p in params.items():
                    if isinstance(p, ScalarParameter):
                        _params[pname] = p
                    else:
                        for idx in npw.ndindex(*p.shape):
                            _params[pname + f"_{idx}"] = p.view(idx)
                parameters[pos] = _params

        super().__init__(
            name=name,
            input_params=input_params,
            axes_shape=axes_shape,
            axes=axes,
            fig=fig,
            **kwds,
        )
        self.custom_axes = custom_axes

        # History buffers: one shared time buffer, plus one value buffer and
        # one line per tracked parameter, indexed as data[pos][param].
        data = {}
        lines = {}
        times = npw.empty(shape=(alloc_size,), dtype=npw.float32)
        for pos, params in parameters.items():
            params_data = {}
            params_lines = {}
            for pname, p in params.items():
                params_data[p] = npw.empty(shape=(alloc_size,), dtype=p.dtype)
                params_lines[p] = self.get_axes(pos).plot([], [], label=pname)[0]
            data[pos] = params_data
            lines[pos] = params_lines

        # FigureCanvasBase.set_window_title was removed in matplotlib 3.6. The
        # title is now set through the figure manager, which does not exist
        # for figures that were not created through pyplot.
        manager = getattr(self.fig.canvas, "manager", None)
        if manager is not None:
            manager.set_window_title("HySoP Parameter Plotter")

        self.parameters = parameters
        self.times = times
        self.data = data
        self.lines = lines
        self.alloc_size = alloc_size
        self.counter = 0

    def get_axes(self, pos):
        """Return the Axes for ``pos``. With custom axes, ``pos`` is the Axes."""
        if self.custom_axes:
            return pos
        return self.axes[pos]

    def __getitem__(self, i):
        """Return the i-th Axes, in flattened order for generated subplots."""
        if self.custom_axes:
            return self.axes[i]
        return self.axes.flatten()[i]

    def update(self, simulation, **kwds):
        """
        Append the current time and parameter values, then refresh the lines.

        Every line shows all recorded samples, including the one just written.
        """
        if self.counter >= self.times.size:
            self._grow_buffers()

        times, data, lines = self.times, self.data, self.lines
        times[self.counter] = simulation.t()
        # Number of valid samples once the current one is written.
        count = self.counter + 1
        for pos, params in self.parameters.items():
            for p in params.values():
                pdata = data[pos][p]
                pdata[self.counter] = p()
                line = lines[pos][p]
                line.set_xdata(times[:count])
                line.set_ydata(pdata[:count])
        self.counter = count

    def _grow_buffers(self):
        """Double the capacity (minimum 1) of the time and value buffers."""
        new_size = max(2 * self.times.size, 1)
        times = npw.empty(shape=(new_size,), dtype=self.times.dtype)
        times[: self.times.size] = self.times
        self.times = times
        for params in self.data.values():
            # Only existing keys are reassigned, which is safe during iteration.
            for p, pdata in params.items():
                new_pdata = npw.empty(shape=(new_size,), dtype=pdata.dtype)
                new_pdata[: pdata.size] = pdata
                params[p] = new_pdata